{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": "output shape: torch.Size([1, 64, 24, 24])\noutput shape: torch.Size([1, 192, 12, 12])\noutput shape: torch.Size([1, 480, 6, 6])\noutput shape: torch.Size([1, 832, 3, 3])\noutput shape: torch.Size([1, 1024, 1, 1])\noutput shape: torch.Size([1, 1024])\noutput shape: torch.Size([1, 10])\n"
    }
   ],
   "source": [
    "import time\n",
    "import torch\n",
    "import torch.nn as nn \n",
    "import torch.nn.functional as F \n",
    "import d2lzh as d2l \n",
    "\n",
    "# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.\n",
    "device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n",
    "\n",
    "class Inception(nn.Module):\n",
    "    \"\"\"Four parallel branches over one input, joined along the channel axis.\n",
    "\n",
    "    Args:\n",
    "        in_c: number of channels of the input.\n",
    "        c1: output channels of branch 1 (a single 1x1 convolution).\n",
    "        c2: indexable pair; c2[0] is the 1x1 output width and c2[1] the 3x3\n",
    "            output width of branch 2.\n",
    "        c3: indexable pair; c3[0] is the 1x1 output width and c3[1] the 5x5\n",
    "            output width of branch 3.\n",
    "        c4: output channels of branch 4 (3x3 max pooling, then a 1x1 convolution).\n",
    "\n",
    "    The result of forward() has c1 + c2[1] + c3[1] + c4 channels.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, in_c, c1, c2, c3, c4):\n",
    "        super().__init__()\n",
    "        # Branch 1: 1x1 convolution.\n",
    "        self.p1_1 = nn.Conv2d(in_c, c1, kernel_size=1)\n",
    "        # Branch 2: 1x1 convolution followed by a padded 3x3 convolution.\n",
    "        self.p2_1 = nn.Conv2d(in_c, c2[0], kernel_size=1)\n",
    "        self.p2_2 = nn.Conv2d(c2[0], c2[1], kernel_size=3, padding=1)\n",
    "        # Branch 3: 1x1 convolution followed by a padded 5x5 convolution.\n",
    "        self.p3_1 = nn.Conv2d(in_c, c3[0], kernel_size=1)\n",
    "        self.p3_2 = nn.Conv2d(c3[0], c3[1], kernel_size=5, padding=2)\n",
    "        # Branch 4: stride-1 padded 3x3 max pooling followed by a 1x1 convolution.\n",
    "        self.p4_1 = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)\n",
    "        self.p4_2 = nn.Conv2d(in_c, c4, kernel_size=1)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Apply every branch to x and concatenate the results on dim 1.\"\"\"\n",
    "        branch1 = F.relu(self.p1_1(x))\n",
    "\n",
    "        reduced2 = F.relu(self.p2_1(x))\n",
    "        branch2 = F.relu(self.p2_2(reduced2))\n",
    "\n",
    "        reduced3 = F.relu(self.p3_1(x))\n",
    "        branch3 = F.relu(self.p3_2(reduced3))\n",
    "\n",
    "        # No activation between the pooling and the 1x1 convolution.\n",
    "        pooled = self.p4_1(x)\n",
    "        branch4 = F.relu(self.p4_2(pooled))\n",
    "\n",
    "        return torch.cat([branch1, branch2, branch3, branch4], dim=1)\n",
    "\n",
    "# Stage 1: 7x7 stride-2 convolution on a 1-channel input, ReLU, then 3x3 stride-2 max pooling.\n",
    "b1 = nn.Sequential(\n",
    "    nn.Conv2d(in_channels=1, out_channels=64, kernel_size=7, stride=2, padding=3),\n",
    "    nn.ReLU(),\n",
    "    nn.MaxPool2d(kernel_size=3, stride=2, padding=1),\n",
    ")\n",
    "# Stage 2: 1x1 convolution, 3x3 convolution (64 -> 192 channels), then 3x3 stride-2 max pooling.\n",
    "# Each convolution is followed by a ReLU: without one in between, the two\n",
    "# convolutions compose into a single linear map and the extra layer adds no\n",
    "# expressive power. This also matches b1 and Inception, which activate every conv.\n",
    "# NOTE: adding the activations shifts the Sequential indices of the second conv\n",
    "# (1 -> 2), so state_dict keys saved from the earlier layout will not line up.\n",
    "b2 = nn.Sequential(\n",
    "    nn.Conv2d(64, 64, kernel_size=1),\n",
    "    nn.ReLU(),\n",
    "    nn.Conv2d(64, 192, kernel_size=3, padding=1),\n",
    "    nn.ReLU(),\n",
    "    nn.MaxPool2d(kernel_size=3, stride=2, padding=1),\n",
    ")\n",
    "# Stage 3: two Inception modules, then 3x3 stride-2 max pooling.\n",
    "b3 = nn.Sequential(\n",
    "    # 64 + 128 + 32 + 32 = 256 output channels\n",
    "    Inception(in_c=192, c1=64, c2=(96, 128), c3=(16, 32), c4=32),\n",
    "    # 128 + 192 + 96 + 64 = 480 output channels\n",
    "    Inception(in_c=256, c1=128, c2=(128, 192), c3=(32, 96), c4=64),\n",
    "    nn.MaxPool2d(kernel_size=3, stride=2, padding=1),\n",
    ")\n",
    "# Stage 4: five Inception modules, then 3x3 stride-2 max pooling.\n",
    "# The comment above each module is the sum of its four branch widths, which\n",
    "# must equal the in_c of the module that follows.\n",
    "b4 = nn.Sequential(\n",
    "    # 192 + 208 + 48 + 64 = 512. The 5x5 branch reduces to 16 channels as in\n",
    "    # GoogLeNet's inception (4a); the previous value 26 was a typo that left\n",
    "    # the output width unchanged and so went unnoticed in the shape check.\n",
    "    Inception(480, 192, (96, 208), (16, 48), 64),\n",
    "    # 160 + 224 + 64 + 64 = 512\n",
    "    Inception(512, 160, (112, 224), (24, 64), 64),\n",
    "    # 128 + 256 + 64 + 64 = 512\n",
    "    Inception(512, 128, (128, 256), (24, 64), 64),\n",
    "    # 112 + 288 + 64 + 64 = 528\n",
    "    Inception(512, 112, (144, 288), (32, 64), 64),\n",
    "    # 256 + 320 + 128 + 128 = 832\n",
    "    Inception(528, 256, (160, 320), (32, 128), 128),\n",
    "    nn.MaxPool2d(kernel_size=3, stride=2, padding=1),\n",
    ")\n",
    "# Stage 5: two Inception modules, then the d2l global average pooling layer.\n",
    "b5 = nn.Sequential(\n",
    "    # 256 + 320 + 128 + 128 = 832 output channels\n",
    "    Inception(in_c=832, c1=256, c2=(160, 320), c3=(32, 128), c4=128),\n",
    "    # 384 + 384 + 128 + 128 = 1024 output channels\n",
    "    Inception(in_c=832, c1=384, c2=(192, 384), c3=(48, 128), c4=128),\n",
    "    d2l.GlobalAvgPool2d(),\n",
    ")\n",
    "\n",
    "# Full model: the five stages, a flatten layer, and a 10-way linear classifier\n",
    "# fed by the 1024 channels that stage 5 produces.\n",
    "net = nn.Sequential(\n",
    "    b1,\n",
    "    b2,\n",
    "    b3,\n",
    "    b4,\n",
    "    b5,\n",
    "    d2l.FlattenLayer(),\n",
    "    nn.Linear(1024, 10),\n",
    ")\n",
    "\n",
    "# Shape check: push one random 1x1x96x96 input through the top-level modules\n",
    "# one at a time and print the shape after each.\n",
    "x = torch.rand(1, 1, 96, 96)\n",
    "for blk in net:\n",
    "    x = blk(x)\n",
    "    print(\"output shape:\", x.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": "training on  cuda\nepoch 1, loss 1.3262, train acc 0.471,test acc 0.777,time 73.1\nepoch 2, loss 0.2244, train acc 0.834,test acc 0.849,time 71.6\nepoch 3, loss 0.1209, train acc 0.864,test acc 0.866,time 71.9\nepoch 4, loss 0.0795, train acc 0.881,test acc 0.857,time 72.0\nepoch 5, loss 0.0576, train acc 0.891,test acc 0.893,time 71.9\n"
    }
   ],
   "source": [
    "# Hyperparameters\n",
    "batch_size = 128\n",
    "lr = 0.001\n",
    "num_epochs = 5\n",
    "\n",
    "# Data iterators from the d2l Fashion-MNIST loader; resize=96 matches the\n",
    "# 96x96 input used for the shape check of the network.\n",
    "train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=96)\n",
    "\n",
    "# Adam over all parameters of the model, then the d2l training loop.\n",
    "optimizer = torch.optim.Adam(net.parameters(), lr=lr)\n",
    "d2l.train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7-final"
  },
  "orig_nbformat": 2,
  "kernelspec": {
   "name": "python37764bitpytorchnotebookconda6e7a8693df0d4d92aca09d521275d23a",
   "display_name": "Python 3.7.7 64-bit ('pytorch_notebook': conda)"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}